{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a9d53a0a-2c87-4eb1-b007-8710cdc0fa7d",
   "metadata": {},
   "source": [
    "# StableDiffusion模型LoRA微调训练\n",
    "\n",
    "PAI提供了丰富的，面向图像、语音、文本等多个领域，基于预训练模型的算法，能够支持用户简单便捷地完成预训练模型的微调训练。\n",
    "\n",
    "本文档将介绍如何通过PAI Python SDK调用PAI平台上提供的 `kohya_lora_trainer` 训练算法，他基于开源社区提供的`kohya_ss`构建，支持StableDiffusion模型[LoRA(Low-Rank Adaptation)](https://arxiv.org/abs/2106.09685)高效微调训练，从而获得可以生成特定领域风格图片的LoRA模型。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e4eab8b",
   "metadata": {},
   "source": [
    "\n",
    "## 费用说明\n",
    "\n",
    "本示例将会使用以下云产品，并产生相应的费用账单：\n",
    "\n",
    "- PAI-DLC：运行训练任务，详细计费说明请参考[PAI-DLC计费说明](https://help.aliyun.com/zh/pai/product-overview/billing-of-dlc)\n",
    "- OSS：存储训练任务输出的模型、训练代码、TensorBoard日志等，详细计费说明请参考[OSS计费概述](https://help.aliyun.com/zh/oss/product-overview/billing-overview)\n",
    "\n",
    "\n",
    "> 通过参与云产品免费试用，使用**指定资源机型**提交训练作业或是部署推理服务，可以免费试用PAI产品，具体请参考[PAI免费试用](https://help.aliyun.com/zh/pai/product-overview/free-quota-for-new-users)。\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6db6ef2e-13b6-4a54-9dfd-0b52d1eb28fe",
   "metadata": {},
   "source": [
    "## 准备工作\n",
    "\n",
    "通过以下命令安装PAI Python SDK。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07c4dcbe-7812-4c79-8a7b-db1ce0b070ce",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "!python -m pip install --upgrade pai"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22dbf882-e993-42a9-93b9-9dfb62d3a0b9",
   "metadata": {},
   "source": [
    "\n",
    "SDK 需要配置访问阿里云服务需要的 AccessKey，以及当前使用的工作空间和OSS Bucket。在 PAI Python SDK 安装之后，通过在 **命令行终端** 中执行以下命令，按照引导配置密钥，工作空间等信息。\n",
    "\n",
    "\n",
    "```shell\n",
    "\n",
    "# 以下命令，请在 命令行终端 中执行.\n",
    "\n",
    "python -m pai.toolkit.config\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "552916bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.session import get_default_session, setup_default_session\n",
    "\n",
    "# 用户通过以下的方式配置SDK的session.\n",
    "sess = get_default_session()\n",
    "if not sess:\n",
    "    print(\"config session\")\n",
    "    sess = setup_default_session(\n",
    "        region_id=\"<REGION_ID>\",  # 例如：cn-beijing\n",
    "        workspace_id=\"<WORKSPACE_ID>\",  # 例如：12345\n",
    "        oss_bucket_name=\"<OSS_BUCKET_NAME>\",\n",
    "    )\n",
    "# 将当前的配置持久化到 ~/.pai/config.json，SDK默认从对应的路径读取配置初始化默认session。\n",
    "sess.save_config()\n",
    "\n",
    "print(sess.workspace_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bac94bff",
   "metadata": {},
   "source": [
    "我们可以通过以下代码验证当前的配置。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82259c8c-b895-4bda-b38c-b7bc2b24fc01",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pai\n",
    "from pai.session import get_default_session\n",
    "\n",
    "print(pai.__version__)\n",
    "sess = get_default_session()\n",
    "print(sess.workspace_id)\n",
    "print(sess.workspace_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdf0a67c",
   "metadata": {},
   "source": [
    "## 算法简介\n",
    "\n",
    "\n",
    "通过 `pai.estimator.AlgorithmEstimator` 对象，用户可以轻松的在 SDK 中调用 PAI 平台的已注册算法（指定 `algorithm_provider` 为 `pai` 便可以调用 PAI 平台所提供的官方算法）。\n",
    "\n",
    "\n",
    "获取`kohya_lora_trainer`算法示例代码如下：\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61e54d3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pai.estimator import AlgorithmEstimator\n",
    "\n",
    "# 创建 AlgorithmEstimator 对象\n",
    "est = AlgorithmEstimator(\n",
    "    algorithm_name=\"kohya_lora_trainer\",  # 算法名称\n",
    "    algorithm_version=\"v0.1.0\",  # 算法版本\n",
    "    algorithm_provider=\"pai\",  # 算法提供者，制定为 pai 以调用 PAI 平台所提供的官方算法\n",
    "    base_job_name=\"SD_kohya_lora_training\",  # 任务名称\n",
    "    # 用户可以查看 est.supported_instance_types 来获取支持的实例类型；\n",
    "    # 用户可以根据自己的需求选择合适的实例类型，不填写则会选择默认的实例类型\n",
    "    # instance_type=\"ecs.gn6v-c8g1.2xlarge\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "471452a8",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "通过`AlgorithmEstimator.hyperparameter_definitions`，`AlgorithmEstimator.input_channel_definitions`等属性，我们可以获取算法支持的参数定义，输入定义等信息。\n",
    "\n",
    "查看`kohya_lora_trainer`训练算法支持的超参："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2dfe9fa",
   "metadata": {
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table><tr><td style=\"text-align:left;\">HyperParameter Name</td><td style=\"text-align:left;\">Type</td><td style=\"text-align:left;\">DefaultValue</td><td style=\"text-align:left;\">Required</td><td style=\"text-align:left;\">Description</td></tr><tr><td style=\"text-align:left;\">max_epochs</td><td style=\"text-align:left;\">Int</td><td style=\"text-align:left;\">10</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">Total train epochs.</td></tr><tr><td style=\"text-align:left;\">batch_size</td><td style=\"text-align:left;\">Int</td><td style=\"text-align:left;\">1</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">Train batch size per gpu.</td></tr><tr><td style=\"text-align:left;\">learning_rate</td><td style=\"text-align:left;\">Float</td><td style=\"text-align:left;\">0.0001</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">Initial learning rate.</td></tr><tr><td style=\"text-align:left;\">resolution</td><td style=\"text-align:left;\">String</td><td style=\"text-align:left;\">512</td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">image resolution in training (\"size\" or \"width,height\")</td></tr><tr><td style=\"text-align:left;\">network_module</td><td style=\"text-align:left;\">String</td><td style=\"text-align:left;\">networks.lora</td><td style=\"text-align:left;\">True</td><td style=\"text-align:left;\">Network module name. Accept Values: [\"networks.lora\", \"networks.dylora\"]</td></tr><tr><td style=\"text-align:left;\">network_dim</td><td style=\"text-align:left;\">Int</td><td style=\"text-align:left;\">32</td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">The number of dimensions of LoRA. The greater the number, the greater the expressive power, but the memory and time required for learning also increase. In addition, it seems that it is not good to increase it blindly..</td></tr><tr><td style=\"text-align:left;\">network_alpha</td><td style=\"text-align:left;\">Float</td><td style=\"text-align:left;\">32.0</td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">Alpha for LoRA weight scaling(if 1 same as network_dim for same behavior as old version)</td></tr><tr><td style=\"text-align:left;\">network_args</td><td style=\"text-align:left;\">String</td><td style=\"text-align:left;\"></td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">Optional, default None. Additional arguments for network, such as '\"conv_dim=16\" \"unit=4\"'</td></tr><tr><td style=\"text-align:left;\">save_every_n_epochs</td><td style=\"text-align:left;\">Int</td><td style=\"text-align:left;\">5</td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">Save checkpoint interval by epoch.</td></tr><tr><td style=\"text-align:left;\">max_token_length</td><td style=\"text-align:left;\">Int</td><td style=\"text-align:left;\">225</td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">Max token length of text encoder (default for 75, 150 or 225).</td></tr><tr><td style=\"text-align:left;\">caption_extension</td><td style=\"text-align:left;\">String</td><td style=\"text-align:left;\">.txt</td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">Extension of caption files.</td></tr><tr><td style=\"text-align:left;\">save_model_as</td><td style=\"text-align:left;\">String</td><td style=\"text-align:left;\">safetensors</td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">Format to save the model (default is .safetensors). Accept Values: [\"ckpt\", \"pt\", \"safetensors\"]\"</td></tr><tr><td style=\"text-align:left;\">mixed_precision</td><td style=\"text-align:left;\">String</td><td style=\"text-align:left;\">fp16</td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Accept Values: [\"no\", \"fp16\", \"bf16\"]</td></tr><tr><td style=\"text-align:left;\">v2</td><td style=\"text-align:left;\">String</td><td style=\"text-align:left;\">false</td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">Use 'true' if load Stable Diffusion v2.0 model.</td></tr><tr><td style=\"text-align:left;\">clip_skip</td><td style=\"text-align:left;\">Int</td><td style=\"text-align:left;\">2</td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">Use output of nth layer from back of text encoder (n>=1). Specifying 2 for the clip_skip option uses the output of the next-to-last layer. If 1 or option is omitted, the last layer is used. SD2.0 uses the second layer from the back by default, so please do not specify it when learning SD2.0</td></tr><tr><td style=\"text-align:left;\">random_crop</td><td style=\"text-align:left;\">String</td><td style=\"text-align:left;\">true</td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">Enable random crop.</td></tr><tr><td style=\"text-align:left;\">flip_aug</td><td style=\"text-align:left;\">String</td><td style=\"text-align:left;\">false</td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">Enable horizontal flip augmentation.</td></tr><tr><td style=\"text-align:left;\">color_aug</td><td style=\"text-align:left;\">String</td><td style=\"text-align:left;\">false</td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">Enable weak color augmentation.</td></tr><tr><td style=\"text-align:left;\">gradient_accumulation_steps</td><td style=\"text-align:left;\">Int</td><td style=\"text-align:left;\">1</td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">gradient accumulation steps</td></tr><tr><td style=\"text-align:left;\">noise_offset</td><td style=\"text-align:left;\">Float</td><td style=\"text-align:left;\"></td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">Enable noise offset (if enabled, around 0.1 is recommended).</td></tr><tr><td style=\"text-align:left;\">seed</td><td style=\"text-align:left;\">Int</td><td style=\"text-align:left;\">42</td><td style=\"text-align:left;\">False</td><td style=\"text-align:left;\">A seed for reproducible training.</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from IPython.display import HTML, display\n",
    "\n",
    "\n",
    "def table_display(data):\n",
    "    display(\n",
    "        HTML(\n",
    "            \"<table><tr>{}</tr></table>\".format(\n",
    "                \"</tr><tr>\".join(\n",
    "                    '<td style=\"text-align:left;\">{}</td>'.format(\n",
    "                        '</td><td style=\"text-align:left;\">'.join(str(_) for _ in row)\n",
    "                    )\n",
    "                    for row in data\n",
    "                )\n",
    "            )\n",
    "        )\n",
    "    )\n",
    "\n",
    "\n",
    "hp_defs = [[\"HyperParameter Name\", \"Type\", \"DefaultValue\", \"Required\", \"Description\"]]\n",
    "\n",
    "for ch in est.hyperparameter_definitions:\n",
    "    hp_defs.append(\n",
    "        [\n",
    "            ch[\"Name\"],\n",
    "            ch[\"Type\"],\n",
    "            est.hyperparameters.get(ch[\"Name\"], ch[\"DefaultValue\"]),\n",
    "            ch[\"Required\"],\n",
    "            ch[\"Description\"],\n",
    "        ]\n",
    "    )\n",
    "\n",
    "table_display(hp_defs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26c3737d",
   "metadata": {},
   "source": [
    "通过`est.hyperparameters`属性和`est.set_hyperparameters` 方法，用户可以查看当前算法配置使用的超参，以及修改超参配置。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7bb8e58",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(est.hyperparameters)\n",
    "\n",
    "# 配置算法超参\n",
    "est.set_hyperparameters(\n",
    "    learning_rate=2e-04,\n",
    "    batch_size=4,\n",
    ")\n",
    "\n",
    "print(est.hyperparameters)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9f0cde7",
   "metadata": {},
   "source": [
    "Tips: 使用DyLoRA进行训练\n",
    "\n",
    "\n",
    "当前算法也支持使用[DyLoRA: Parameter Efficient Tuning of Pre-trained Models using Dynamic Search-Free Low-Rank Adaptation](https://arxiv.org/abs/2210.07558)。该论文提出LoRA的rank并不是越高越好，而是需要根据模型、数据集、任务等因素来寻找合适的rank。使用DyLoRA，可以同时在指定的维度(rank)下学习多种rank的LoRA，从而省去了寻找最佳rank的麻烦。\n",
    "\n",
    "```python\n",
    "\n",
    "est.set_hyperparameters(**{\n",
    "\t\"network_module\": \"networks.dylora\",\n",
    "\t\"network_args\": '\"unit=4\"',\n",
    "\t\"network_dim\": 16,\n",
    "})\n",
    "```\n",
    "\n",
    "例如，当使用以上的超参配置时，dim=16、unit=4进行学习时，可以学习和提取4个rank的LoRA，即4、8、12和16。通过在每个提取的模型中生成图像并进行比较，可以选择最佳rank的LoRA。 如果未指定unit，则默认为unit=1。\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51751740",
   "metadata": {},
   "source": [
    "## 准备训练数据\n",
    "\n",
    "训练数据需要准备在OSS上，具体的数据格式可参考：[image_folder_structure.md](https://github.com/bmaltais/kohya_ss/blob/master/docs/image_folder_structure.md) 。\n",
    "\n",
    "在训练目录文件夹内需要准备供模型训练的图片和txt标注文件，txt标注文件来注明每张图片对应的prompt文本。当txt标注文件不存在时，模型将使用图片目录名作为作为prompt文本。\n",
    "\n",
    "> 图像标注文件的生成可使用PAI快速开始提供“image-captioning”系列模型。\n",
    "\n",
    "数据示例：\n",
    "\n",
    "```plain\n",
    "├──train_data\n",
    "| ├── 30_dog\n",
    "| | |\n",
    "| | ├── image1.jpg\n",
    "| | ├── image1.txt\n",
    "| | ├── image2.png\n",
    "| | └── image2.txt\n",
    "| | └── ...\n",
    "```\n",
    "\n",
    "`imgxx.txt`的文件内容如下：\n",
    "\n",
    "```text\n",
    "a dog is running\n",
    "```\n",
    "\n",
    "\n",
    "PAI提供了一份简单的训练数据示例，可以直接用于kohya_lora_trainer的训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e28dc74",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = (\n",
    "    \"oss://pai-quickstart-{region}/aigclib/datasets/kohya_ss/lora/v0.1.0/train/\".format(\n",
    "        region=sess.region_id\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88fa82f4",
   "metadata": {},
   "source": [
    "用户可以通过以下的代码，下载PAI提供训练数据到本地查看。\n",
    "\n",
    "```python\n",
    "\n",
    "from pai.common.oss_util import download\n",
    "download(train_data, \"./train_data\")\n",
    "\n",
    "```\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ee2e100-867c-4c9b-83cc-6142ff3c25e5",
   "metadata": {},
   "source": [
    "## 提交训练作业\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4db6911d",
   "metadata": {},
   "source": [
    "`kohya_lora_trainer` 算法要求提供的输入：\n",
    "- `pretrained_model` 输入：\n",
    "  \n",
    "预训练模型，相应的目录下应该包含一个StableDiffusion格式的预训练模型，可以是`.safetensors`文件格式或是一个`.ckpt`文件格式，或是HuggingFace Diffusers支持的预训练模型格式。\n",
    "\n",
    "- `train` 输入:\n",
    "\n",
    "训练的输入数据，默认是OSS URI格式(`oss://<Bucket>/path/to/train/data/`)\n",
    "\n",
    "\n",
    "可以通过以下代码查看训练算法的输入输出：\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57caedf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "io_defs = [[\"Channel Name\", \"Input/Output\", \"Required\", \"Description\"]]\n",
    "for ch in est.input_channel_definitions:\n",
    "    io_defs.append([ch[\"Name\"], \"Input\", ch[\"Required\"], ch[\"Description\"]])\n",
    "for ch in est.output_channel_definitions:\n",
    "    io_defs.append([ch[\"Name\"], \"Output\", ch[\"Required\"], ch[\"Description\"]])\n",
    "\n",
    "\n",
    "table_display(io_defs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d76b8d9",
   "metadata": {},
   "source": [
    "在本示例中，我们将使用PAI提供的一份简单训练数据集和预训练模型进行训练。 通过 `AlgorithmEstimator.fit()` 方法我们可以提交训练任务。SDK会打印训练作业的控制台URL，用户可以通过该URL查看训练作业的详细信息以及日志。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83fdd767-c573-4914-a11a-b180b671efe0",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "region = sess.region_id\n",
    "# PAI提供的训练数据集和预训练模型(SD1.5)\n",
    "train = (f\"oss://pai-quickstart-{region}/aigclib/datasets/kohya_ss/lora/v0.1.0/train/\",)\n",
    "pretrained_model = f\"oss://pai-quickstart-{region}/aigclib/models/custom_civitai_models/sd1.5/v1-5-pruned.safetensors\"\n",
    "\n",
    "# 提交训练作业\n",
    "est.fit(inputs={\"pretrained_model\": pretrained_model, \"train\": train})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26b8323e",
   "metadata": {},
   "source": [
    "\n",
    "任务结束后，用户可以通过 `AlgorithmEstimator.model_data()` 方法来查看产出模型的地址信息。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71e8e4fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 查看输出的模型地址\n",
    "print(est.model_data())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6a66193",
   "metadata": {},
   "source": [
    "模型默认存储在OSS上，用户可以通过以下代码将模型下载到本地。\n",
    "\n",
    "```python\n",
    "\n",
    "from pai.common.oss_util import download\n",
    "\n",
    "download(est.model_data(), \"./model/)\n",
    "\n",
    "\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
